{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n作者：英俊\\nQQ:2227495940\\n邮箱 同上\\n'"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''\n",
    "作者：英俊\n",
    "QQ:2227495940\n",
    "邮箱 同上\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解决方案"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 1 机器学习方案\n",
    "\n",
    "通过机器学习为手段进行文本分类\n",
    "\n",
    "* 2 深度学习方案\n",
    "\n",
    "通过深度学习为手段进行文本分类\n",
    "\n",
    "* 3 fastnlp方案的应用\n",
    "\n",
    "这种模式需要先将现成的数据集读取并且以csv/txt格式导出然后读取"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 进行数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import re\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "# clean useless characters\n",
    "'''\n",
    "html_clean = ['& ldquo ;', '& hellip ;', '& rdquo ;', '& yen ;']\n",
    "punctuation_replace = '[，。！？]+'\n",
    "strange_num = ['①','②','③','④']\n",
    "'''\n",
    "punctuation_remove = '[：；……（）『』《》【】～!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~]+'\n",
    "\n",
    "def clean(sent):\n",
    "    \"\"\"Normalise one raw review line.\n",
    "\n",
    "    Removes leftover HTML-entity names, maps the circled digit ⑦ to 7, deletes\n",
    "    runs of two or more '， ', '！ ', '？ ' or '。 ', strips every character\n",
    "    matched by ``punctuation_remove`` and collapses whitespace.\n",
    "    \"\"\"\n",
    "    # (pattern, replacement) pairs, applied strictly in this order\n",
    "    substitutions = (\n",
    "        (r'ldquo', \"\"),\n",
    "        (r'hellip', \"\"),\n",
    "        (r'rdquo', \"\"),\n",
    "        (r'yen', \"\"),\n",
    "        (r'⑦', \"7\"),\n",
    "        (r'(， ){2,}', \"\"),\n",
    "        (r'(！ ){2,}', \"\"),\n",
    "        (r'(？ ){2,}', \"\"),\n",
    "        (r'(。 ){2,}', \"\"),\n",
    "        (punctuation_remove, \"\"),\n",
    "    )\n",
    "    for pattern, replacement in substitutions:\n",
    "        sent = re.sub(pattern, replacement, sent)\n",
    "    # split() + join collapses whitespace runs and trims both ends\n",
    "    return ' '.join(sent.split())\n",
    "    \n",
    "def sent_filter(l):\n",
    "    l_new = []\n",
    "    for s,k in enumerate(l):\n",
    "        if len(k) > 2:\n",
    "            l_new.append(k)\n",
    "    return l_new\n",
    "\n",
    "# Shared by both loaders below\n",
    "def _load_clean_examples(data_file):\n",
    "    \"\"\"Read one review file and return its cleaned lines.\n",
    "\n",
    "    Every line is stripped, passed through ``clean()``, stripped again and the\n",
    "    resulting list is filtered with ``sent_filter()``.\n",
    "    \"\"\"\n",
    "    # The with-block closes the file handle even if reading fails\n",
    "    with open(data_file, \"r\", encoding='utf-8') as handle:\n",
    "        raw_lines = [line.strip() for line in handle]\n",
    "    cleaned = [clean(line).strip() for line in raw_lines]\n",
    "    return sent_filter(cleaned)\n",
    "\n",
    "# Loader for the deep-learning models (one-hot labels)\n",
    "def dl_load_data_and_labels(good_data_file, bad_data_file, mid_data_file):\n",
    "    \"\"\"Load the three review files and return ``[x_text, y]`` with one-hot labels.\n",
    "\n",
    "    ``x_text`` is the list of cleaned good, bad and mid reviews, in that order.\n",
    "    ``y`` is an integer array with one row of three columns per review:\n",
    "    good -> [1, 0, 0], bad -> [0, 1, 0], mid -> [0, 0, 1].\n",
    "    \"\"\"\n",
    "    good_examples = _load_clean_examples(good_data_file)\n",
    "    bad_examples = _load_clean_examples(bad_data_file)\n",
    "    mid_examples = _load_clean_examples(mid_data_file)\n",
    "    x_text = good_examples + bad_examples + mid_examples\n",
    "\n",
    "    labels = ([[1, 0, 0]] * len(good_examples)\n",
    "              + [[0, 1, 0]] * len(bad_examples)\n",
    "              + [[0, 0, 1]] * len(mid_examples))\n",
    "    # reshape keeps the (n, 3) layout even when one class (or all of them) is empty\n",
    "    y = np.array(labels, dtype=int).reshape(-1, 3)\n",
    "    return [x_text, y]\n",
    "\n",
    "# Loader for the classical machine-learning models (integer class ids)\n",
    "def ml_load_data_and_labels(good_data_file, bad_data_file, mid_data_file):\n",
    "    \"\"\"Load the three review files and return ``[x_text, y]`` with class ids.\n",
    "\n",
    "    ``x_text`` is the list of cleaned good, bad and mid reviews, in that order.\n",
    "    ``y`` is a 1-D integer array: good -> 0, bad -> 1, mid -> 2.\n",
    "    \"\"\"\n",
    "    good_examples = _load_clean_examples(good_data_file)\n",
    "    bad_examples = _load_clean_examples(bad_data_file)\n",
    "    mid_examples = _load_clean_examples(mid_data_file)\n",
    "    x_text = good_examples + bad_examples + mid_examples\n",
    "\n",
    "    labels = [0] * len(good_examples) + [1] * len(bad_examples) + [2] * len(mid_examples)\n",
    "    # Explicit integer dtype: an empty class must not promote the labels to float\n",
    "    y = np.array(labels, dtype=int)\n",
    "    return [x_text, y]\n",
    "\n",
    "# when you use tensorflow, you need to generate batches yourself, this function may helpe you\n",
    "def batch_iter(data, batch_size, num_epochs, shuffle=True):\n",
    "    \"\"\"\n",
    "    Generates a batch iterator for a dataset.\n",
    "    \"\"\"\n",
    "    data = np.array(data)\n",
    "    data_size = len(data)\n",
    "    num_batches_per_epoch = int((len(data)-1)/batch_size) + 1\n",
    "    for epoch in range(num_epochs):\n",
    "        # Shuffle the data at each epoch\n",
    "        if shuffle:\n",
    "            shuffle_indices = np.random.permutation(np.arange(data_size))\n",
    "            shuffled_data = data[shuffle_indices]\n",
    "        else:\n",
    "            shuffled_data = data\n",
    "        for batch_num in range(num_batches_per_epoch):\n",
    "            start_index = batch_num * batch_size\n",
    "            end_index = min((batch_num + 1) * batch_size, data_size)\n",
    "            yield shuffled_data[start_index:end_index]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 ... 2 2 2]\n"
     ]
    }
   ],
   "source": [
    "good_data_file = \"./data/good_cut_jieba.txt\"\n",
    "bad_data_file = \"./data/bad_cut_jieba.txt\"\n",
    "mid_data_file = \"./data/mid_cut_jieba.txt\"\n",
    "x_text, y = ml_load_data_and_labels(good_data_file, bad_data_file, mid_data_file)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15648\n",
      "15648\n"
     ]
    }
   ],
   "source": [
    "print(len(y))\n",
    "print(len(x_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_dict={\"raw_words\":x_text,\"target\":y}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>raw_words</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>整个 感觉 除了 送货 师傅 新手 来晚 了 外 ， 其他 都 很 好</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>京东 自营 很 不错 ，</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>外观 手感 使用 都 不错 ， 好 ！</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>相信 这个 牌子 ， 没有 验证 过 保修期</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>好 ， 售后服务 到位 ， 能 解决 使用 中 遇见 的 问题 。</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                             raw_words  target\n",
       "0  整个 感觉 除了 送货 师傅 新手 来晚 了 外 ， 其他 都 很 好       0\n",
       "1                         京东 自营 很 不错 ，       0\n",
       "2                  外观 手感 使用 都 不错 ， 好 ！       0\n",
       "3               相信 这个 牌子 ， 没有 验证 过 保修期       0\n",
       "4    好 ， 售后服务 到位 ， 能 解决 使用 中 遇见 的 问题 。       0"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_all=pd.DataFrame(data_dict)\n",
    "df_all.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 数据分割\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# 数据管道\n",
    "from sklearn.pipeline import Pipeline,make_pipeline\n",
    "# 数据分割\n",
    "x_train, x_test, y_train, y_test = train_test_split(x_text, y, test_size=0.2, random_state=2017)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_dict1={\"raw_words\":x_train,\"target\":y_train}\n",
    "data_dict2={\"raw_words\":x_test}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df_train=pd.DataFrame(data_dict1)\n",
    "df_train.head()\n",
    "# 为fastnlp做准备\n",
    "df_train.to_csv('train.txt',sep='\\t', index=False,header=None,encoding='utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df_test=pd.DataFrame(data_dict2)\n",
    "df_test.head()\n",
    "# 为fastnlp做准备\n",
    "df_test.to_csv('test.txt',sep='\\t', index=False,header=None,encoding='utf-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 机器学习方案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sklearn\n",
    "#机器学习算法模型\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.ensemble import RandomForestClassifier,BaggingClassifier,AdaBoostClassifier\n",
    "from sklearn.svm import SVC,LinearSVC\n",
    "from sklearn.naive_bayes import BernoulliNB,MultinomialNB,GaussianNB\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "# 特征提取\n",
    "from sklearn.feature_extraction.text import CountVectorizer,TfidfVectorizer\n",
    "from sklearn.model_selection import train_test_split\n",
    "#Pipeline 使用一系列 (key, value) 键值对来构建,其中 key 是你给这个步骤起的名字， value 是一个评估器对象:\n",
    "from sklearn.pipeline import Pipeline\n",
    "#准确率，精确率，召回率，f1\n",
    "from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,classification_report\n",
    "import xgboost as xgb\n",
    "import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 读取数据\n",
    "def readFile(path):\n",
    "    with open(path, 'r', errors='ignore') as file:  # 文档中编码有些问题，所有用errors过滤错误\n",
    "        content = file.read()\n",
    "        return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the stop-word list.\n",
    "# 'utf-8-sig' strips the byte-order mark that otherwise stays glued to the\n",
    "# first entry ('\\ufeff,'); the with-block closes the file handle.\n",
    "with open('stopword.txt', 'r', encoding='utf-8-sig') as stopword_file:\n",
    "    stwlist = [line.strip() for line in stopword_file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# pipe = make_pipeline(CountVectorizer(), LogisticRegression())     \n",
    "# # 这里是后台优化回头再说\n",
    "# param_grid = [{'logisticregression__C': [1, 10, 100, 1000]}\n",
    "# gs = GridSearchCV(pipe, param_grid)\n",
    "# gs.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 文本向量化工具"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 创建各类cv=CountVectorizer()和tf_idf工具\n",
    "cv=CountVectorizer(min_df=3,\n",
    "                      max_df=0.5,\n",
    "                      ngram_range=(1,2),\n",
    "                      stop_words = stwlist)\n",
    "tdf=TfidfVectorizer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#%%\n",
    "# Count vectoriser --> LogisticRegression()\n",
    "\n",
    "# 分类模型\n",
    "\n",
    "#1.逻辑回归\n",
    "lr=LogisticRegression()\n",
    "# 贝叶斯\n",
    "#2.多项式贝叶斯\n",
    "mb=MultinomialNB()\n",
    "gb=GaussianNB()\n",
    "#3.伯努利贝叶斯\n",
    "bb=BernoulliNB()\n",
    "# 支持向量机\n",
    "#4.支持向量机\n",
    "svc=SVC(kernel='rbf')\n",
    "svc1=SVC(kernel='linear')\n",
    "svc2=SVC(kernel='poly')\n",
    "svc3=SVC(kernel='sigmoid')\n",
    "\n",
    "#5.\n",
    "linearsvc=LinearSVC()\n",
    "#6.决策树\n",
    "dtc=DecisionTreeClassifier(random_state=22)\n",
    "#7.随机森林\n",
    "rfc=RandomForestClassifier(random_state=22)\n",
    "#9.KNN分类器\n",
    "knn=KNeighborsClassifier()\n",
    "\n",
    "modelList=[lr,mb,bb,svc,svc1,svc2,svc3,linearsvc,dtc,rfc,knn]\n",
    "\n",
    "#11个模型\n",
    "m_len=len(modelList)\n",
    "\n",
    "# # 形成 9个模型 2个提取特征 5个指标  \n",
    "# # 提取特征分类器\n",
    "textVectoriser=[cv,tdf]\n",
    "textv_len=len(textVectoriser)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "new_ticks = []\n",
    "name=[]\n",
    "# modelNamelist=['逻辑回归','多项式贝叶斯','伯努利贝叶斯','高斯贝叶斯','RBF核SVM'\n",
    "#                ,'线性核SVM','多项式核SVM','sigmoid核SVM'\n",
    "#                ,'线性分类SVM','决策树','随机森林','KNN']\n",
    "# modelNamelist2=['lr','mb','gb','bb','svc','svc1','svc2','svc3','l'+'\\n'+'svc','dtc','rfc','knn']\n",
    "modelNamelist=['逻辑回归','多项式贝叶斯','伯努利贝叶斯','RBF核SVM'\n",
    "               ,'线性核SVM','多项式核SVM','sigmoid核SVM'\n",
    "               ,'线性分类SVM','决策树','随机森林','KNN']\n",
    "modelNamelist2=['lr','mb','bb','svc','svc1','svc2','svc3','l'+'\\n'+'svc','dtc','rfc','knn']\n",
    "# textVectorNamelist = ['词袋','TDF']\n",
    "for i in range(m_len):\n",
    "        new_ticks.append([modelNamelist2[i]])\n",
    "        name.append(modelNamelist[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "name_dict={\"name\":modelNamelist,\"model\":modelList}\n",
    "label_dict={\"name\":modelNamelist2,\"model\":modelList}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "accuracy_score_list=[]\n",
    "precision_score_list=[]\n",
    "recall_score_list=[]\n",
    "f1_score_list=[]\n",
    "classification_report_list=[]\n",
    "modelClass=[]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 记录结果\n",
    "result=pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
      "                dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',\n",
      "                lowercase=True, max_df=0.5, max_features=None, min_df=3,\n",
      "                ngram_range=(1, 2), preprocessor=None,\n",
      "                stop_words=['\\ufeff,', '?', '、', '。', '“', '”', '《', '》', '！',\n",
      "                            '，', '：', '；', '？', '人民', '#', '###', '啊', '阿', '哎',\n",
      "                            '哎呀', '哎哟', '唉', '俺', '俺们', '按', '按照', '吧', '吧哒',\n",
      "                            '把', '罢了', ...],\n",
      "                strip_accents=None, token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n",
      "                tokenizer=None, vocabulary=None)\n"
     ]
    }
   ],
   "source": [
    "tracv=CountVectorizer().fit_transform(x_train)\n",
    "print(cv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\anaconda\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:385: UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens ['ain', 'aren', 'couldn', 'didn', 'doesn', 'don', 'hadn', 'hasn', 'haven', 'isn', 'll', 'mon', 'shouldn', 've', 'wasn', 'weren', 'won', 'wouldn'] not in stop_words.\n",
      "  'stop_words.' % sorted(inconsistent))\n",
      "C:\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:939: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html.\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================================================================================================\n",
      "当前模型是： 逻辑回归 当前文本向量化是 词袋 当前准确率是： 0.66933\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\anaconda\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:939: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html.\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================================================================================================\n",
      "当前模型是： 逻辑回归 当前文本向量化是 TF-IDF 当前准确率是： 0.71406\n",
      "======================================================================================================================================================\n",
      "当前模型是： 多项式贝叶斯 当前文本向量化是 词袋 当前准确率是： 0.68275\n",
      "======================================================================================================================================================\n",
      "当前模型是： 多项式贝叶斯 当前文本向量化是 TF-IDF 当前准确率是： 0.71022\n",
      "======================================================================================================================================================\n",
      "当前模型是： 伯努利贝叶斯 当前文本向量化是 词袋 当前准确率是： 0.64249\n",
      "======================================================================================================================================================\n",
      "当前模型是： 伯努利贝叶斯 当前文本向量化是 TF-IDF 当前准确率是： 0.62492\n",
      "======================================================================================================================================================\n",
      "当前模型是： RBF核SVM 当前文本向量化是 词袋 当前准确率是： 0.66773\n",
      "======================================================================================================================================================\n",
      "当前模型是： RBF核SVM 当前文本向量化是 TF-IDF 当前准确率是： 0.71821\n",
      "======================================================================================================================================================\n",
      "当前模型是： 线性核SVM 当前文本向量化是 词袋 当前准确率是： 0.65847\n",
      "======================================================================================================================================================\n",
      "当前模型是： 线性核SVM 当前文本向量化是 TF-IDF 当前准确率是： 0.71374\n",
      "======================================================================================================================================================\n",
      "当前模型是： 多项式核SVM 当前文本向量化是 词袋 当前准确率是： 0.46102\n",
      "======================================================================================================================================================\n",
      "当前模型是： 多项式核SVM 当前文本向量化是 TF-IDF 当前准确率是： 0.68179\n",
      "======================================================================================================================================================\n",
      "当前模型是： sigmoid核SVM 当前文本向量化是 词袋 当前准确率是： 0.65176\n",
      "======================================================================================================================================================\n",
      "当前模型是： sigmoid核SVM 当前文本向量化是 TF-IDF 当前准确率是： 0.70192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\anaconda\\lib\\site-packages\\sklearn\\svm\\_base.py:947: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  \"the number of iterations.\", ConvergenceWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================================================================================================\n",
      "当前模型是： 线性分类SVM 当前文本向量化是 词袋 当前准确率是： 0.64505\n",
      "======================================================================================================================================================\n",
      "当前模型是： 线性分类SVM 当前文本向量化是 TF-IDF 当前准确率是： 0.69904\n",
      "======================================================================================================================================================\n",
      "当前模型是： 决策树 当前文本向量化是 词袋 当前准确率是： 0.57923\n",
      "======================================================================================================================================================\n",
      "当前模型是： 决策树 当前文本向量化是 TF-IDF 当前准确率是： 0.59585\n",
      "======================================================================================================================================================\n",
      "当前模型是： 随机森林 当前文本向量化是 词袋 当前准确率是： 0.64409\n"
     ]
    }
   ],
   "source": [
    "accuracy_score_list=[]\n",
    "# Display names for the entries of textVectoriser (keep both lists aligned)\n",
    "vectoriser_names=['词袋','TF-IDF']\n",
    "\n",
    "# Vectorise once per vectoriser: the features do not depend on the classifier,\n",
    "# so refitting inside the model loop would only repeat identical work.\n",
    "vectorised=[]\n",
    "for j in range(textv_len):\n",
    "    train_vec=textVectoriser[j].fit_transform(x_train)\n",
    "    test_vec=textVectoriser[j].transform(x_test)\n",
    "    vectorised.append((train_vec,test_vec))\n",
    "\n",
    "# Scores are appended model-major: (model 0, 词袋), (model 0, TF-IDF), (model 1, 词袋), ...\n",
    "for i in range(m_len):\n",
    "    for j in range(textv_len):\n",
    "        train_vec,test_vec=vectorised[j]\n",
    "        modelList[i].fit(train_vec,y_train)\n",
    "        pred=modelList[i].predict(test_vec)\n",
    "        accuracy=round(accuracy_score(y_test,pred),5)\n",
    "        print('='*150)\n",
    "        print('当前模型是：',modelNamelist[i],'当前文本向量化是',vectoriser_names[j],\"当前准确率是：\",accuracy)\n",
    "        accuracy_score_list.append(accuracy)\n",
    "\n",
    "# NOTE: precision_score / recall_score / f1_score default to average='binary',\n",
    "# which raises on these 3-class labels; pass average='macro' (or 'weighted')\n",
    "# to collect them here as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\anaconda\\lib\\site-packages\\sklearn\\feature_extraction\\text.py:385: UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens ['ain', 'aren', 'couldn', 'didn', 'doesn', 'don', 'hadn', 'hasn', 'haven', 'isn', 'll', 'mon', 'shouldn', 've', 'wasn', 'weren', 'won', 'wouldn'] not in stop_words.\n",
      "  'stop_words.' % sorted(inconsistent))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================================================================================================\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'pred' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-23-c5e79d4a1c97>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     44\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'='\u001b[0m\u001b[1;33m*\u001b[0m\u001b[1;36m150\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m         \u001b[1;32mif\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m==\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 46\u001b[1;33m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'当前模型是：xgboost'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m'当前文本向量化是'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m'词袋'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"当前准确率是：\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mround\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maccuracy_score\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mpred\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     47\u001b[0m         \u001b[1;32mif\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m==\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     48\u001b[0m             \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'当前模型是：xgboost'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m'当前文本向量化是'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m'TF-IDF'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;34m\"当前准确率是：\"\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mround\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maccuracy_score\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mpred\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'pred' is not defined"
     ]
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "# Booster parameters for the 3-class classification task\n",
    "params = {\n",
    "    'booster': 'gbtree',\n",
    "    'objective': 'multi:softmax',\n",
    "    'num_class': 3,\n",
    "    'gamma': 0.1,\n",
    "    'max_depth': 6,\n",
    "    'lambda': 2,\n",
    "    'subsample': 0.7,\n",
    "    'colsample_bytree': 0.75,\n",
    "    'min_child_weight': 3,\n",
    "    # NOTE(review): newer XGBoost releases replace 'silent' with 'verbosity' - confirm against the installed version\n",
    "    'silent': 0,\n",
    "    'eta': 0.1,\n",
    "    'seed': 1,\n",
    "    'nthread': 4,\n",
    "}\n",
    "num_rounds = 500\n",
    "# Display names for the entries of textVectoriser (keep both lists aligned)\n",
    "vectoriser_names = ['词袋', 'TF-IDF']\n",
    "\n",
    "# One XGBoost run per text vectoriser. The model does not depend on the sklearn\n",
    "# classifiers, so there is no loop over modelList here.\n",
    "for j in range(textv_len):\n",
    "    train_vec=textVectoriser[j].fit_transform(x_train)\n",
    "    test_vec=textVectoriser[j].transform(x_test)\n",
    "\n",
    "    dtrain = xgb.DMatrix(train_vec, y_train) # training data in XGBoost's format\n",
    "    model = xgb.train(params, dtrain, num_rounds) # train the booster\n",
    "\n",
    "    # Predict on the test set\n",
    "    dtest = xgb.DMatrix(test_vec)\n",
    "    pred = model.predict(dtest)\n",
    "\n",
    "    # Accuracy of the predictions against the held-out labels\n",
    "    accuracy = accuracy_score(y_test,pred)\n",
    "    print('='*150)\n",
    "    print('当前模型是：xgboost','当前文本向量化是',vectoriser_names[j],\"当前准确率是：\",round(accuracy,5))\n",
    "    # Keep the name list and the score list growing together\n",
    "    modelNamelist.append('xgboost')\n",
    "    accuracy_score_list.append(round(accuracy,5))\n",
    "    print(\"accuarcy: %.2f%%\" % (accuracy*100.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# result[\"模型名称\"]=modelNamelist\n",
    "# result[\"准确率\"]=accuracy_score_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# result.to_csv('机器学习不同模型准确率.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from pylab import *\n",
    "mpl.rcParams['font.sans-serif'] = ['SimHei']\n",
    "#plot根据列表绘制出有意义的图形，linewidth是图形线宽，可省略\n",
    "# plt.plot(input_values,squares,linewidth=5)\n",
    "plt.figure(figsize=(12,5),dpi=80)\n",
    "plt.bar(range(len(accuracy_score_list)),accuracy_score_list,linewidth=5)\n",
    "#设置图标标题\n",
    "plt.title(\"不同管道模型准确率\",fontsize = 24)\n",
    "#设置坐标轴标签\n",
    "plt.xlabel(\"模型类型\",fontsize = 0.2)\n",
    "plt.ylabel(\"准确率\",fontsize = 0.5)\n",
    "#设置刻度标记的大小\n",
    "plt.tick_params(axis='both',labelsize = 14)\n",
    "#打开matplotlib查看器，并显示绘制图形\n",
    "#这是一半\n",
    "plt.xticks(range(new_ticks),new_ticks)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 深度学习方案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 保证映射后结构一样\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "# 文本预处理\n",
    "from keras.preprocessing.text import Tokenizer\n",
    "# 将类别映射成需要的格式\n",
    "from keras.utils.np_utils import to_categorical\n",
    "\n",
    "# 这个是连接层\n",
    "from keras.layers.merge import concatenate\n",
    "\n",
    "# 搭建模型\n",
    "from keras.models import Sequential, Model\n",
    "\n",
    "# 这个是层的搭建\n",
    "from keras.layers import Dense, Embedding, Activation, Input\n",
    "\n",
    "from keras.layers import Convolution1D, Flatten, Dropout, MaxPool1D\n",
    "\n",
    "from keras.layers import  BatchNormalization\n",
    "from keras.layers import Conv1D,MaxPooling1D\n",
    "\n",
    "\n",
    "# 导入使用到的库\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.preprocessing.text import Tokenizer\n",
    "from keras.layers.merge import concatenate\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Dense, Embedding, Activation, merge, Input, Lambda, Reshape\n",
    "from keras.layers import Convolution1D, Flatten, Dropout, MaxPool1D, GlobalAveragePooling1D\n",
    "from keras.layers import LSTM, GRU, TimeDistributed, Bidirectional\n",
    "from keras.utils.np_utils import to_categorical\n",
    "from keras import initializers\n",
    "from keras import backend as K\n",
    "from keras.engine.topology import Layer\n",
    "# from sklearn.naive_bayes import MultinomialNB\n",
    "# from sklearn.linear_model import SGDClassifier\n",
    "# from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "# import pandas as pd\n",
    "# import numpy as np\n",
    "\n",
    "\n",
    "# 数据处理\n",
    "# from data_helper_ml import load_data_and_labels\n",
    "\n",
    "# 数据可视化\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 文本标签分类数量\n",
    "NUM_CLASS=3\n",
    "# 输入维度\n",
    "INPUT_SIZE=64\n",
    "# # 序列对齐文本数据\n",
    "# LENTH=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Tokenizer是一个用于向量化文本，或将文本转换为序列\n",
    "tokenizer = Tokenizer(filters='!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n',lower=True,split=\" \")\n",
    "tokenizer.fit_on_texts(x_text)\n",
    "vocab = tokenizer.word_index\n",
    "\n",
    "#映射成数字\n",
    "x_train_word_ids = tokenizer.texts_to_sequences(x_train)\n",
    "x_test_word_ids = tokenizer.texts_to_sequences(x_test)\n",
    "#让他共同化\n",
    "x_train_padded_seqs = pad_sequences(x_train_word_ids, maxlen=INPUT_SIZE)\n",
    "x_test_padded_seqs = pad_sequences(x_test_word_ids, maxlen=INPUT_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# CNN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cnn():\n",
    "    model = Sequential()\n",
    "    model.add(Embedding(len(vocab) + 1, 300, input_length=INPUT_SIZE)) #使用Embeeding层将每个词编码转换为词向量\n",
    "    model.add(Conv1D(256, 5, padding='same'))\n",
    "    model.add(MaxPooling1D(3, 3, padding='same'))\n",
    "    model.add(Conv1D(128, 5, padding='same'))\n",
    "    model.add(MaxPooling1D(3, 3, padding='same'))\n",
    "    model.add(Conv1D(64, 3, padding='same'))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dropout(0.1))\n",
    "    model.add(BatchNormalization())  # (批)规范化层\n",
    "    model.add(Dense(256, activation='relu'))\n",
    "    model.add(Dropout(0.1))\n",
    "    model.add(Dense(NUM_CLASS, activation='softmax'))\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# textCNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def textCNN():\n",
    "    main_input = Input(shape=(64,), dtype='float64')\n",
    "    # 词嵌入（使用预训练的词向量）\n",
    "    embedder = Embedding(len(vocab) + 1, 300, input_length=INPUT_SIZE, trainable=False)\n",
    "\n",
    "    embed = embedder(main_input)\n",
    "\n",
    "    # 词窗大小分别为3,4,5\n",
    "    cnn1 = Conv1D(256, 3, padding='same', strides=1, activation='relu')(embed)\n",
    "\n",
    "    cnn1 = MaxPooling1D(pool_size=48)(cnn1)\n",
    "\n",
    "    cnn2 = Conv1D(256, 4, padding='same', strides=1, activation='relu')(embed)\n",
    "\n",
    "    cnn2 = MaxPooling1D(pool_size=47)(cnn2)\n",
    "\n",
    "    cnn3 = Conv1D(256, 5, padding='same', strides=1, activation='relu')(embed)\n",
    "\n",
    "    cnn3 = MaxPooling1D(pool_size=46)(cnn3)\n",
    "\n",
    "    # 合并三个模型的输出向量\n",
    "    cnn = concatenate([cnn1, cnn2, cnn3], axis=-1)\n",
    "\n",
    "    flat = Flatten()(cnn)\n",
    "\n",
    "    drop = Dropout(0.2)(flat)\n",
    "\n",
    "    main_output = Dense(NUM_CLASS, activation='softmax')(drop)\n",
    "\n",
    "    model = Model(inputs=main_input, outputs=main_output)\n",
    "\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 使用Word2Vec词向量的TextCNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# w2v_model=Word2Vec.load('sentiment_analysis/w2v_model.pkl')\n",
    "# # 预训练的词向量中没有出现的词用0向量表示\n",
    "# embedding_matrix = np.zeros((len(vocab) + 1, 300))\n",
    "# for word, i in vocab.items():\n",
    "#     try:\n",
    "#         embedding_vector = w2v_model[str(word)]\n",
    "#         embedding_matrix[i] = embedding_vector\n",
    "#     except KeyError:\n",
    "#         continue\n",
    "        \n",
    "#  #构建TextCNN模型\n",
    "# def TextCNN_model_2():\n",
    "#     # 模型结构：词嵌入-卷积池化*3-拼接-全连接-dropout-全连接\n",
    "#     main_input = Input(shape=(INPUT_SIZE,), dtype='float64')\n",
    "#     # 词嵌入（使用预训练的词向量）\n",
    "#     embedder = Embedding(len(vocab) + 1, 300, input_length=INPUT_SIZE, weights=[embedding_matrix], trainable=False)\n",
    "#     #embedder = Embedding(len(vocab) + 1, 300, input_length=50, trainable=False)\n",
    "#     embed = embedder(main_input)\n",
    "#     # 词窗大小分别为3,4,5\n",
    "#     cnn1 = Conv1D(256, 3, padding='same', strides=1, activation='relu')(embed)\n",
    "#     cnn1 = MaxPooling1D(pool_size=38)(cnn1)\n",
    "#     cnn2 = Conv1D(256, 4, padding='same', strides=1, activation='relu')(embed)\n",
    "#     cnn2 = MaxPooling1D(pool_size=37)(cnn2)\n",
    "#     cnn3 = Conv1D(256, 5, padding='same', strides=1, activation='relu')(embed)\n",
    "#     cnn3 = MaxPooling1D(pool_size=36)(cnn3)\n",
    "#     # 合并三个模型的输出向量\n",
    "#     cnn = concatenate([cnn1, cnn2, cnn3], axis=-1)\n",
    "#     flat = Flatten()(cnn)\n",
    "#     drop = Dropout(0.2)(flat)\n",
    "#     main_output = Dense(3, activation='softmax')(drop)\n",
    "#     model = Model(inputs=main_input, outputs=main_output)\n",
    "#     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    " \n",
    "#     one_hot_labels = keras.utils.to_categorical(y_train, num_classes=NUM_CLASS)  # 将标签转换为one-hot编码\n",
    "# #     model.fit(x_train_padded_seqs, one_hot_labels, batch_size=800, epochs=20)\n",
    "# #     #y_test_onehot = keras.utils.to_categorical(y_test, num_classes=3)  # 将标签转换为one-hot编码\n",
    "# #     result = model.predict(x_test_padded_seqs)  # 预测样本属于每个类别的概率\n",
    "# #     result_labels = np.argmax(result, axis=1)  # 获得最大概率对应的标签\n",
    "# #     y_predict = list(map(str, result_labels))\n",
    "# #     print('准确率', metrics.accuracy_score(y_test, y_predict))\n",
    "# #     print('平均f1-score:', metrics.f1_score(y_test, y_predict, average='weighted'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def rnn():\n",
    "        # 模型结构：词嵌入-LSTM-全连接\n",
    "    model = Sequential()\n",
    "    model.add(Embedding(len(vocab)+1, 300, input_length=INPUT_SIZE))\n",
    "    model.add(LSTM(256, dropout=0.2, recurrent_dropout=0.1))\n",
    "    model.add(Dense(NUM_CLASSM, activation='softmax'))\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Bi-GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def digru():\n",
    "    # 模型结构：词嵌入-双向GRU*2-全连接\n",
    "    model = Sequential()\n",
    "    # 64是序列号\n",
    "    model.add(Embedding(len(vocab)+1, 300, input_length=INPUT_SIZE))\n",
    "    model.add(Bidirectional(GRU(256, dropout=0.2, recurrent_dropout=0.1, return_sequences=True)))\n",
    "    model.add(Bidirectional(GRU(256, dropout=0.2, recurrent_dropout=0.1)))\n",
    "    model.add(Dense(NUM_CLASSM_C, activation='softmax'))\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# CNN+RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "##串联"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def clstm():\n",
    "    # 模型结构：词嵌入-卷积池化-GRU*2-全连接\n",
    "    model = Sequential()\n",
    "    model.add(Embedding(len(vocab)+1, 300, input_length=INPUT_SIZE))\n",
    "    model.add(Convolution1D(256, 3, padding='same', strides = 1))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPool1D(pool_size=2))\n",
    "    model.add(GRU(256, dropout=0.2, recurrent_dropout=0.1, return_sequences = True))\n",
    "    model.add(GRU(256, dropout=0.2, recurrent_dropout=0.1))\n",
    "    model.add(Dense(NUM_CLASS, activation='softmax'))\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 并联"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def blstm():\n",
    "    # 模型结构：词嵌入-卷积池化-全连接 ---拼接-全连接\n",
    "    #                -双向GRU-全连接\n",
    "    main_input = Input(shape=(INPUT_SIZE,), dtype='float64')\n",
    "    embed = Embedding(len(vocab)+1, 300, input_length=INPUT_SIZE)(main_input)\n",
    "    cnn = Convolution1D(256, 3, padding='same', strides = 1, activation='relu')(embed)\n",
    "    cnn = MaxPool1D(pool_size=4)(cnn)\n",
    "    cnn = Flatten()(cnn)\n",
    "    cnn = Dense(256)(cnn)\n",
    "    rnn = Bidirectional(GRU(256, dropout=0.2, recurrent_dropout=0.1))(embed)\n",
    "    rnn = Dense(256)(rnn)\n",
    "    con = concatenate([cnn,rnn], axis=-1)\n",
    "    main_output = Dense(NUM_CLASS, activation='softmax')(con)\n",
    "    model = Model(inputs = main_input, outputs = main_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# fasttext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 模型结构：词嵌入(n-gram)-最大化池化-全连接\n",
    "# 生成n-gram组合的词(以3为例)\n",
    "ngram = 3\n",
    "# 将n-gram词加入到词表\n",
    "def create_ngram(sent, ngram_value):\n",
    "    return set(zip(*[sent[i:] for i in range(ngram_value)]))\n",
    "ngram_set = set()\n",
    "for sentence in x_train_padded_seqs:\n",
    "    for i in range(2, ngram+1):\n",
    "        set_of_ngram = create_ngram(sentence, i)\n",
    "        ngram_set.update(set_of_ngram)\n",
    "        \n",
    "# 给n-gram词汇编码\n",
    "start_index = len(vocab) + 2\n",
    "token_indice = {v: k + start_index for k, v in enumerate(ngram_set)} # 给n-gram词汇编码\n",
    "indice_token = {token_indice[k]: k for k in token_indice}\n",
    "max_features = np.max(list(indice_token.keys())) + 1\n",
    "# 将n-gram词加入到输入文本的末端\n",
    "def add_ngram(sequences, token_indice, ngram_range):\n",
    "    new_sequences = []\n",
    "    for sent in sequences:\n",
    "        new_list = sent[:]\n",
    "        for i in range(len(new_list) - ngram_range + 1):\n",
    "            for ngram_value in range(2, ngram_range + 1):\n",
    "                ngram = tuple(new_list[i:i + ngram_value])\n",
    "                if ngram in token_indice:\n",
    "                    new_list.append(token_indice[ngram])\n",
    "        new_sequences.append(new_list)\n",
    "    return new_sequences\n",
    "\n",
    "\n",
    "x_train = add_ngram(x_train_word_ids, token_indice, ngram)\n",
    "x_test = add_ngram(x_test_word_ids, token_indice, ngram)\n",
    "# x_train = pad_sequences(x_train, maxlen=25)\n",
    "# x_test = pad_sequences(x_test, maxlen=25)\n",
    "x_train_padded_seqs = pad_sequences(x_train_word_ids, maxlen=INPUT_SIZE)\n",
    "x_test_padded_seqs = pad_sequences(x_test_word_ids, maxlen=INPUT_SIZE)\n",
    "def fasttext():\n",
    "    model = Sequential()\n",
    "    model.add(Embedding(max_features, 300, input_length=INPUT_SIZE))\n",
    "    model.add(GlobalAveragePooling1D())\n",
    "    model.add(Dense(NUM_CLASS, activation='softmax'))\n",
    "    model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# keras_bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import codecs, gc\n",
    "import numpy as np\n",
    "from sklearn.model_selection import KFold\n",
    "from keras_bert import load_trained_model_from_checkpoint, Tokenizer\n",
    "from keras.metrics import top_k_categorical_accuracy\n",
    "from keras.layers import *\n",
    "from keras.callbacks import *\n",
    "from keras.models import Model\n",
    "import keras.backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras.utils import to_categorical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "maxlen = INPUT_SIZE  #设置序列长度为120，要保证序列长度不超过512\n",
    " \n",
    "#预训练好的模型\n",
    "# 还是放在原有样本中\n",
    "# path=os.path.join(dirpath,config_path)\n",
    "# os.path.join()\n",
    "config_path = 'bert_config.json'\n",
    "# config_path=os.path.join(dirpath,config_path)\n",
    "checkpoint_path = 'bert_model.ckpt'\n",
    "# checkpoint_path=os.path.join(dirpath,checkpoint_path)\n",
    "dict_path = 'vocab.txt'\n",
    "# dict_path=os.path.join(dirpath,checkpoint_path)\n",
    "#将词表中的词编号转换为字典\n",
    "token_dict = {}\n",
    "with codecs.open(dict_path, 'r', 'utf8') as reader:\n",
    "    for line in reader:\n",
    "        token = line.strip()\n",
    "        token_dict[token] = len(token_dict)\n",
    "#重写tokenizer        \n",
    "class OurTokenizer(Tokenizer):\n",
    "    def _tokenize(self, text):\n",
    "        R = []\n",
    "        for c in text:\n",
    "            if c in self._token_dict:\n",
    "                R.append(c)\n",
    "            elif self._is_space(c):\n",
    "                R.append('[unused1]')  # 用[unused1]来表示空格类字符\n",
    "            else:\n",
    "                R.append('[UNK]')  # 不在列表的字符用[UNK]表示\n",
    "        return R\n",
    "tokenizer = OurTokenizer(token_dict)\n",
    "#让每条文本的长度相同，用0填充\n",
    "def seq_padding(X, padding=0):\n",
    "    L = [len(x) for x in X]\n",
    "    ML = max(L)\n",
    "    return np.array([\n",
    "        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X\n",
    "    ])\n",
    "#data_generator只是一种为了节约内存的数据方式\n",
    "class data_generator:\n",
    "    def __init__(self, data, batch_size=32, shuffle=True):\n",
    "        self.data = data\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        self.steps = len(self.data) // self.batch_size\n",
    "        if len(self.data) % self.batch_size != 0:\n",
    "            self.steps += 1\n",
    " \n",
    "    def __len__(self):\n",
    "        return self.steps\n",
    " \n",
    "    def __iter__(self):\n",
    "        while True:\n",
    "            idxs = list(range(len(self.data)))\n",
    " \n",
    "            if self.shuffle:\n",
    "                np.random.shuffle(idxs)\n",
    " \n",
    "            X1, X2, Y = [], [], []\n",
    "            for i in idxs:\n",
    "                d = self.data[i]\n",
    "                text = d[0][:maxlen]\n",
    "                x1, x2 = tokenizer.encode(first=text)\n",
    "                y = d[1]\n",
    "                X1.append(x1)\n",
    "                X2.append(x2)\n",
    "                Y.append([y])\n",
    "                if len(X1) == self.batch_size or i == idxs[-1]:\n",
    "                    X1 = seq_padding(X1)\n",
    "                    X2 = seq_padding(X2)\n",
    "                    Y = seq_padding(Y)\n",
    "                    yield [X1, X2], Y[:, 0, :]\n",
    "                    [X1, X2, Y] = [], [], []\n",
    "#bert模型设置\n",
    "def build_bert(nclass):\n",
    "    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)  #加载预训练模型\n",
    " \n",
    "    for l in bert_model.layers:\n",
    "        l.trainable = True\n",
    " \n",
    "    x1_in = Input(shape=(None,))\n",
    "    x2_in = Input(shape=(None,))\n",
    " \n",
    "    x = bert_model([x1_in, x2_in])\n",
    "    x = Lambda(lambda x: x[:, 0])(x) # 取出[CLS]对应的向量用来做分类\n",
    "    p = Dense(nclass, activation='softmax')(x)\n",
    " \n",
    "    model = Model([x1_in, x2_in], p)\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "                  optimizer=Adam(1e-5),    #用足够小的学习率\n",
    "                  metrics=['accuracy', acc_top2])\n",
    "    print(model.summary())\n",
    "    return model\n",
    "#计算top-k正确率,当预测值的前k个值中存在目标类别即认为预测正确                 \n",
    "def acc_top2(y_true, y_pred):\n",
    "    return top_k_categorical_accuracy(y_true, y_pred, k=2)\n",
    "#训练数据、测试数据和标签转化为模型输入格式\n",
    "DATA_LIST = []\n",
    "for data_row in train_df1.iloc[:].itertuples():\n",
    "    DATA_LIST.append((xtrain, to_categorical(ytraim, NUM_CLASS)))\n",
    "DATA_LIST = np.array(DATA_LIST)\n",
    " \n",
    "DATA_LIST_TEST = []\n",
    "for data_row in test_df1.iloc[:].itertuples():\n",
    "    DATA_LIST_TEST.append((xtest, to_categorical(0, NUM_CLASS)))\n",
    "DATA_LIST_TEST = np.array(DATA_LIST_TEST)\n",
    "#交叉验证训练和测试模型\n",
    "def run_cv(nfold, data, data_labels, data_test):\n",
    "    kf = KFold(n_splits=nfold, shuffle=True, random_state=520).split(data)\n",
    "    train_model_pred = np.zeros((len(data), 3))\n",
    "    test_model_pred = np.zeros((len(data_test), 3))\n",
    " \n",
    "    for i, (train_fold, test_fold) in enumerate(kf):\n",
    "        X_train, X_valid, = data[train_fold, :], data[test_fold, :]\n",
    " \n",
    "        model = build_bert(NUM_CLASS)\n",
    "        early_stopping = EarlyStopping(monitor='val_acc', patience=3)   #早停法，防止过拟合\n",
    "        plateau = ReduceLROnPlateau(monitor=\"val_acc\", verbose=1, mode='max', factor=0.5, patience=2) #当评价指标不在提升时，减少学习率\n",
    "        checkpoint = ModelCheckpoint('./bert_dump/' + str(i) + '.hdf5', monitor='val_acc',verbose=2, save_best_only=True, mode='max', save_weights_only=True) #保存最好的模型\n",
    " \n",
    "        train_D = data_generator(X_train, shuffle=True)\n",
    "        valid_D = data_generator(X_valid, shuffle=True)\n",
    "        test_D = data_generator(data_test, shuffle=False)\n",
    "        #模型训练\n",
    "        model.fit_generator(\n",
    "            train_D.__iter__(),\n",
    "            steps_per_epoch=len(train_D),\n",
    "            epochs=2,\n",
    "            validation_data=valid_D.__iter__(),\n",
    "            validation_steps=len(valid_D),\n",
    "            callbacks=[early_stopping, plateau, checkpoint],\n",
    "        )\n",
    " \n",
    "        # model.load_weights('./bert_dump/' + str(i) + '.hdf5')\n",
    " \n",
    "        # return model\n",
    "        train_model_pred[test_fold, :] = model.predict_generator(valid_D.__iter__(), steps=len(valid_D), verbose=1)\n",
    "        test_model_pred += model.predict_generator(test_D.__iter__(), steps=len(test_D), verbose=1)\n",
    " \n",
    "        del model\n",
    "        gc.collect()   #清理内存\n",
    "        K.clear_session()   #clear_session就是清除一个session\n",
    "        # break\n",
    " \n",
    "    return train_model_pred, test_model_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# import os\n",
    "\n",
    "# print(os.getcwd()) #打印出当前工作路径\n",
    "\n",
    "# 很容易就崩溃了\n",
    "#n折交叉验证\n",
    "train_model_pred, test_model_pred = run_cv(2, DATA_LIST, None, DATA_LIST_TEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 64, 300)           4594500   \n",
      "_________________________________________________________________\n",
      "conv1d_1 (Conv1D)            (None, 64, 256)           384256    \n",
      "_________________________________________________________________\n",
      "max_pooling1d_1 (MaxPooling1 (None, 22, 256)           0         \n",
      "_________________________________________________________________\n",
      "conv1d_2 (Conv1D)            (None, 22, 128)           163968    \n",
      "_________________________________________________________________\n",
      "max_pooling1d_2 (MaxPooling1 (None, 8, 128)            0         \n",
      "_________________________________________________________________\n",
      "conv1d_3 (Conv1D)            (None, 8, 64)             24640     \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 512)               2048      \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 256)               131328    \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 256)               0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 3)                 771       \n",
      "=================================================================\n",
      "Total params: 5,301,511\n",
      "Trainable params: 5,300,487\n",
      "Non-trainable params: 1,024\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model=cnn() #rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])\n",
    "one_hot_labels = to_categorical(y_train, num_classes=NUM_CLASS)  # 将标签转换为one-hot编码\n",
    "# one_hot_labels=y_train\n",
    "model.fit(x_train_padded_seqs, one_hot_labels,epochs=5, batch_size=800)\n",
    "y_predict = model.predict_classes(x_test_padded_seqs)  # 预测的是类别，结果就是类别号\n",
    "y_predict = list(map(str, y_predict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 外挂神器FastNlp方案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#导入Pytorch包\n",
    "\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from fastNLP.io.loader import CSVLoader\n",
    "\n",
    "dataset_loader = CSVLoader(headers=('raw_words','target'), sep='\\t')\n",
    "testset_loader = CSVLoader( headers=['raw_words'],sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 表示将CSV文件中每一行的第一项将填入'raw_words' field，第二项填入'target' field。\n",
    "\n",
    "# 其中项之间由'\\t'分割开来\n",
    "\n",
    "train_path=r'train.txt'\n",
    "\n",
    "test_path=r'test.txt'\n",
    "\n",
    "dataset = dataset_loader._load(train_path)\n",
    "\n",
    "testset = testset_loader._load(test_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.39\n"
     ]
    }
   ],
   "source": [
    "# 将句子分成单词形式, 详见DataSet.apply()方法\n",
    "\n",
    "import jieba\n",
    "\n",
    "from itertools import chain\n",
    "\n",
    "print(jieba.__version__)\n",
    "# from itertools import chain\n",
    "\n",
    "#     '''\n",
    "\n",
    "#     @params:\n",
    "\n",
    "#         data: 数据的列表，列表中的每个元素为 [文本字符串，0/1标签] 二元组\n",
    "\n",
    "#     @return: 切分词后的文本的列表，列表中的每个元素为切分后的词序列\n",
    "\n",
    "#     '''\n",
    "\n",
    "def get_tokenized(data,words=True):\n",
    "    def tokenizer(text):\n",
    "        return [tok for tok in jieba.cut(text, cut_all=False)]\n",
    "    if words:\n",
    "\n",
    "        #按词语进行编码\n",
    "\n",
    "        return tokenizer(data)\n",
    "\n",
    "    else:\n",
    "\n",
    "        #按字进行编码\n",
    "\n",
    "        return [tokenizer(review) for review in data]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------------------------+--------+\n",
      "| raw_words                           | target |\n",
      "+-------------------------------------+--------+\n",
      "| 这个 系统 很 好 用 ， 用 不 习惯... | 0      |\n",
      "| 我 说 你 信 吗 。                   | 2      |\n",
      "| 快递 特变 慢 ， 价钱 比 别的 网...  | 2      |\n",
      "| 用 了 一个 来 月 ， 发现 很多 问... | 1      |\n",
      "| 早上 买 的 4888 草 下午 4500 。...  | 1      |\n",
      "| 一 被 蹂躏 了 N 次千苍百孔 的 破... | 1      |\n",
      "| 我 很 满意 手机 很 好               | 0      |\n",
      "| 没有 想象 中 的 那么 好 ， 总是...  | 2      |\n",
      "| 昨天 买 的 ， 今天 就 便宜 100 ...  | 1      |\n",
      "| 为什么 我 手机 今天 到 了 没有 ...  | 2      |\n",
      "| 特意 用 了 一段时间 才 来 评价 ...  | 0      |\n",
      "| 早就 想 买 了 ， 送人 的 ！ 物流... | 0      |\n",
      "| 缺点 就是 手机 送来 的 时候 没电... | 2      |\n",
      "| 送货 速度 快 ， 手机 漂亮 ， 大...  | 0      |\n",
      "| 晚上 下单 第二天 上午 就 到 ， ...  | 0      |\n",
      "| 不错 打白条 来 的 为什么 凭 我 ...  | 0      |\n",
      "| 物流 这些 还 可以 ， 专门 使用 ...  | 2      |\n",
      "| 检查 了 是 正品 挺 好 的 ， 在 ...  | 0      |\n",
      "| 手机 是 用 过 的 打 客服 电话 一... | 1      |\n",
      "| 收到 手机 使用 了 有 几天 了 ，...  | 0      |\n",
      "| 快递 很 给力 配送员 态度 也好 可... | 1      |\n",
      "| 还好 吧 ， 就是 信号 很差 ， 在...  | 2      |\n",
      "| 手机 是 没 问题 ， 可是 插头 不...  | 1      |\n",
      "| 挺 好 的 ， 就是 太贵 了 ， 苹果... | 2      |\n",
      "| 完美 的 一次 购物 ， 第二天 到 ...  | 0      |\n",
      "| 信号 不太好 ， 心塞 。              | 2      |\n",
      "| 先不说 手机 刚回来 第一天 就 黑...  | 1      |\n",
      "| 手机 还 行 ， 就是 手机 网络 老...  | 1      |\n",
      "| 买 了 之后 没过几天 就 降价 ， ...  | 1      |\n",
      "| 刚 买回来 ， 很 好 用               | 0      |\n",
      "| 屏幕 有 一个点 ， 小 瑕疵 ， 但...  | 2      |\n",
      "| ...                                 | ...    |\n",
      "+-------------------------------------+--------+\n"
     ]
    }
   ],
   "source": [
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "I0702 18:13:51.150580 13512 __init__.py:114] Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache C:\\Users\\ADMINI~1\\AppData\\Local\\Temp\\jieba.cache\n",
      "I0702 18:13:52.657735 13512 __init__.py:148] Dumping model to file cache C:\\Users\\ADMINI~1\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 1.706 seconds.\n",
      "I0702 18:13:52.862597 13512 __init__.py:166] Loading model cost 1.706 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "I0702 18:13:52.866595 13512 __init__.py:167] Prefix dict has been built successfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------------------+--------+--------------------------+\n",
      "| raw_words                 | target | words                    |\n",
      "+---------------------------+--------+--------------------------+\n",
      "| 这个 系统 很 好 用 ，...  | 0      | ['这个', ' ', '系统',... |\n",
      "| 我 说 你 信 吗 。         | 2      | ['我', ' ', '说', ' '... |\n",
      "| 快递 特变 慢 ， 价钱 ...  | 2      | ['快递', ' ', '特变',... |\n",
      "| 用 了 一个 来 月 ， 发... | 1      | ['用', ' ', '了', ' '... |\n",
      "| 早上 买 的 4888 草 下...  | 1      | ['早上', ' ', '买', '... |\n",
      "| 一 被 蹂躏 了 N 次千苍... | 1      | ['一', ' ', '被', ' '... |\n",
      "| 我 很 满意 手机 很 好...  | 0      | ['我', ' ', '很', ' '... |\n",
      "| 没有 想象 中 的 那么 ...  | 2      | ['没有', ' ', '想象',... |\n",
      "| 昨天 买 的 ， 今天 就...  | 1      | ['昨天', ' ', '买', '... |\n",
      "| 为什么 我 手机 今天 到... | 2      | ['为什么', ' ', '我',... |\n",
      "| 特意 用 了 一段时间 才... | 0      | ['特意', ' ', '用', '... |\n",
      "| 早就 想 买 了 ， 送人...  | 0      | ['早就', ' ', '想', '... |\n",
      "| 缺点 就是 手机 送来 的... | 2      | ['缺点', ' ', '就是',... |\n",
      "| 送货 速度 快 ， 手机 ...  | 0      | ['送货', ' ', '速度',... |\n",
      "| 晚上 下单 第二天 上午...  | 0      | ['晚上', ' ', '下单',... |\n",
      "| 不错 打白条 来 的 为什... | 0      | ['不错', ' ', '打白条... |\n",
      "| 物流 这些 还 可以 ， ...  | 2      | ['物流', ' ', '这些',... |\n",
      "| 检查 了 是 正品 挺 好...  | 0      | ['检查', ' ', '了', '... |\n",
      "| 手机 是 用 过 的 打 客... | 1      | ['手机', ' ', '是', '... |\n",
      "| 收到 手机 使用 了 有 ...  | 0      | ['收到', ' ', '手机',... |\n",
      "| 快递 很 给力 配送员 态... | 1      | ['快递', ' ', '很', '... |\n",
      "| 还好 吧 ， 就是 信号 ...  | 2      | ['还好', ' ', '吧', '... |\n",
      "| 手机 是 没 问题 ， 可...  | 1      | ['手机', ' ', '是', '... |\n",
      "| 挺 好 的 ， 就是 太贵...  | 2      | ['挺', ' ', '好', ' '... |\n",
      "| 完美 的 一次 购物 ， ...  | 0      | ['完美', ' ', '的', '... |\n",
      "| 信号 不太好 ， 心塞 。... | 2      | ['信号', ' ', '不太好... |\n",
      "| 先不说 手机 刚回来 第...  | 1      | ['先不说', ' ', '手机... |\n",
      "| 手机 还 行 ， 就是 手...  | 1      | ['手机', ' ', '还', '... |\n",
      "| 买 了 之后 没过几天 就... | 1      | ['买', ' ', '了', ' '... |\n",
      "| 刚 买回来 ， 很 好 用...  | 0      | ['刚', ' ', '买回来',... |\n",
      "| 屏幕 有 一个点 ， 小 ...  | 2      | ['屏幕', ' ', '有', '... |\n",
      "| ...                       | ...    | ...                      |\n",
      "+---------------------------+--------+--------------------------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda ins:get_tokenized(ins['raw_words']), new_field_name='words', is_input=True)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+--------+----------------------+---------+\n",
      "| raw_words            | target | words                | seq_len |\n",
      "+----------------------+--------+----------------------+---------+\n",
      "| 这个 系统 很 好 ...  | 0      | ['这个', ' ', '系... | 21      |\n",
      "| 我 说 你 信 吗 。... | 2      | ['我', ' ', '说'...  | 11      |\n",
      "| 快递 特变 慢 ， ...  | 2      | ['快递', ' ', '特... | 33      |\n",
      "| 用 了 一个 来 月...  | 1      | ['用', ' ', '了'...  | 124     |\n",
      "| 早上 买 的 4888 ...  | 1      | ['早上', ' ', '买... | 15      |\n",
      "| 一 被 蹂躏 了 N ...  | 1      | ['一', ' ', '被'...  | 35      |\n",
      "| 我 很 满意 手机 ...  | 0      | ['我', ' ', '很'...  | 11      |\n",
      "| 没有 想象 中 的 ...  | 2      | ['没有', ' ', '想... | 27      |\n",
      "| 昨天 买 的 ， 今...  | 1      | ['昨天', ' ', '买... | 25      |\n",
      "| 为什么 我 手机 今... | 2      | ['为什么', ' ', ...  | 15      |\n",
      "| 特意 用 了 一段时... | 0      | ['特意', ' ', '用... | 45      |\n",
      "| 早就 想 买 了 ，...  | 0      | ['早就', ' ', '想... | 29      |\n",
      "| 缺点 就是 手机 送... | 2      | ['缺点', ' ', '就... | 18      |\n",
      "| 送货 速度 快 ， ...  | 0      | ['送货', ' ', '速... | 27      |\n",
      "| 晚上 下单 第二天...  | 0      | ['晚上', ' ', '下... | 19      |\n",
      "| 不错 打白条 来 的... | 0      | ['不错', ' ', '打... | 27      |\n",
      "| 物流 这些 还 可以... | 2      | ['物流', ' ', '这... | 69      |\n",
      "| 检查 了 是 正品 ...  | 0      | ['检查', ' ', '了... | 35      |\n",
      "| 手机 是 用 过 的...  | 1      | ['手机', ' ', '是... | 31      |\n",
      "| 收到 手机 使用 了... | 0      | ['收到', ' ', '手... | 96      |\n",
      "| 快递 很 给力 配送... | 1      | ['快递', ' ', '很... | 40      |\n",
      "| 还好 吧 ， 就是 ...  | 2      | ['还好', ' ', '吧... | 35      |\n",
      "| 手机 是 没 问题 ...  | 1      | ['手机', ' ', '是... | 35      |\n",
      "| 挺 好 的 ， 就是...  | 2      | ['挺', ' ', '好'...  | 25      |\n",
      "| 完美 的 一次 购物... | 0      | ['完美', ' ', '的... | 25      |\n",
      "| 信号 不太好 ， 心... | 2      | ['信号', ' ', '不... | 9       |\n",
      "| 先不说 手机 刚回...  | 1      | ['先不说', ' ', ...  | 86      |\n",
      "| 手机 还 行 ， 就...  | 1      | ['手机', ' ', '还... | 41      |\n",
      "| 买 了 之后 没过几... | 1      | ['买', ' ', '了'...  | 44      |\n",
      "| 刚 买回来 ， 很 ...  | 0      | ['刚', ' ', '买回... | 11      |\n",
      "| 屏幕 有 一个点 ，... | 2      | ['屏幕', ' ', '有... | 17      |\n",
      "| ...                  | ...    | ...                  | ...     |\n",
      "+----------------------+--------+----------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda ins: len(ins['words']) ,new_field_name='seq_len', is_input=True)\n",
    "\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+--------+----------------------+---------+\n",
      "| raw_words            | target | words                | seq_len |\n",
      "+----------------------+--------+----------------------+---------+\n",
      "| 这个 系统 很 好 ...  | 0      | ['这个', ' ', '系... | 21      |\n",
      "| 我 说 你 信 吗 。... | 2      | ['我', ' ', '说'...  | 11      |\n",
      "| 快递 特变 慢 ， ...  | 2      | ['快递', ' ', '特... | 33      |\n",
      "| 用 了 一个 来 月...  | 1      | ['用', ' ', '了'...  | 124     |\n",
      "| 早上 买 的 4888 ...  | 1      | ['早上', ' ', '买... | 15      |\n",
      "| 一 被 蹂躏 了 N ...  | 1      | ['一', ' ', '被'...  | 35      |\n",
      "| 我 很 满意 手机 ...  | 0      | ['我', ' ', '很'...  | 11      |\n",
      "| 没有 想象 中 的 ...  | 2      | ['没有', ' ', '想... | 27      |\n",
      "| 昨天 买 的 ， 今...  | 1      | ['昨天', ' ', '买... | 25      |\n",
      "| 为什么 我 手机 今... | 2      | ['为什么', ' ', ...  | 15      |\n",
      "| 特意 用 了 一段时... | 0      | ['特意', ' ', '用... | 45      |\n",
      "| 早就 想 买 了 ，...  | 0      | ['早就', ' ', '想... | 29      |\n",
      "| 缺点 就是 手机 送... | 2      | ['缺点', ' ', '就... | 18      |\n",
      "| 送货 速度 快 ， ...  | 0      | ['送货', ' ', '速... | 27      |\n",
      "| 晚上 下单 第二天...  | 0      | ['晚上', ' ', '下... | 19      |\n",
      "| 不错 打白条 来 的... | 0      | ['不错', ' ', '打... | 27      |\n",
      "| 物流 这些 还 可以... | 2      | ['物流', ' ', '这... | 69      |\n",
      "| 检查 了 是 正品 ...  | 0      | ['检查', ' ', '了... | 35      |\n",
      "| 手机 是 用 过 的...  | 1      | ['手机', ' ', '是... | 31      |\n",
      "| 收到 手机 使用 了... | 0      | ['收到', ' ', '手... | 96      |\n",
      "| 快递 很 给力 配送... | 1      | ['快递', ' ', '很... | 40      |\n",
      "| 还好 吧 ， 就是 ...  | 2      | ['还好', ' ', '吧... | 35      |\n",
      "| 手机 是 没 问题 ...  | 1      | ['手机', ' ', '是... | 35      |\n",
      "| 挺 好 的 ， 就是...  | 2      | ['挺', ' ', '好'...  | 25      |\n",
      "| 完美 的 一次 购物... | 0      | ['完美', ' ', '的... | 25      |\n",
      "| 信号 不太好 ， 心... | 2      | ['信号', ' ', '不... | 9       |\n",
      "| 先不说 手机 刚回...  | 1      | ['先不说', ' ', ...  | 86      |\n",
      "| 手机 还 行 ， 就...  | 1      | ['手机', ' ', '还... | 41      |\n",
      "| 买 了 之后 没过几... | 1      | ['买', ' ', '了'...  | 44      |\n",
      "| 刚 买回来 ， 很 ...  | 0      | ['刚', ' ', '买回... | 11      |\n",
      "| 屏幕 有 一个点 ，... | 2      | ['屏幕', ' ', '有... | 17      |\n",
      "| ...                  | ...    | ...                  | ...     |\n",
      "+----------------------+--------+----------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "dataset.apply(lambda x: int(x['target']), new_field_name='target', is_target=True)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------------------+--------------------------+---------+\n",
      "| raw_words                 | words                    | seq_len |\n",
      "+---------------------------+--------------------------+---------+\n",
      "| 有 三百 的 优惠卷 买 ...  | ['有', ' ', '三百', '... | 25      |\n",
      "| 充电 那么 慢 ， 比 正...  | ['充电', ' ', '那么',... | 49      |\n",
      "| 很 不错 。 特地 用 了...  | ['很', ' ', '不错', '... | 27      |\n",
      "| 用 了 一段时间 还 可以... | ['用', ' ', '了', ' '... | 25      |\n",
      "| 说 实在 的 手机 不如 ...  | ['说', ' ', '实在', '... | 23      |\n",
      "| 最 差劲 的 一次 购物 ...  | ['最', ' ', '差劲', '... | 209     |\n",
      "| 刚刚 收到 手机 ， 12 ...  | ['刚刚', ' ', '收到',... | 104     |\n",
      "| 还 可以 快递 也 给力      | ['还', ' ', '可以', '... | 9       |\n",
      "| 好 ， 售后服务 到位 ，... | ['好', ' ', '，', ' '... | 25      |\n",
      "| 速度 很快 ， 第二天 早... | ['速度', ' ', '很快',... | 77      |\n",
      "| 手机 有 多处 刮伤         | ['手机', ' ', '有', '... | 7       |\n",
      "| 今天 刚到 ， 感觉 也 ...  | ['今天', ' ', '刚到',... | 42      |\n",
      "| 手机 阴阳屏 ， 退货 了... | ['手机', ' ', '阴阳',... | 28      |\n",
      "| 送人 的 应该 还 可以 ...  | ['送人', ' ', '的', '... | 11      |\n",
      "| 拆箱 防盗 标签 是 错位... | ['拆箱', ' ', '防盗',... | 71      |\n",
      "| 外 包装盒 里 很大 没有... | ['外', ' ', '包装盒',... | 209     |\n",
      "| 快递 小哥 挺快 就是 手... | ['快递', ' ', '小哥',... | 41      |\n",
      "| 京东 真的 是 打着 618...  | ['京东', ' ', '真的',... | 75      |\n",
      "| 在 京东 买 了 不 知道...  | ['在', ' ', '京东', '... | 71      |\n",
      "| 耗电量 太快 了 ， 电池... | ['耗电量', ' ', '太快... | 15      |\n",
      "| 其他 还好 ， 就是 电 ...  | ['其他', ' ', '还好',... | 13      |\n",
      "| 反应 慢 ， 拍照 不好 ...  | ['反应', ' ', '慢', '... | 15      |\n",
      "| 怎么 说 呢 ， 可能 也...  | ['怎么', ' ', '说', '... | 241     |\n",
      "| 快递 发货 快 一天 到 ...  | ['快递', ' ', '发货',... | 19      |\n",
      "| 非常 好用 ， 京东 就是... | ['非常', ' ', '好用',... | 23      |\n",
      "| 目前 正在 使用 ， 暂时... | ['目前', ' ', '正在',... | 83      |\n",
      "| 耳机 是 坏 的 ！          | ['耳机', ' ', '是', '... | 9       |\n",
      "| 买 了 几天 降价 200 差... | ['买', ' ', '了', ' '... | 11      |\n",
      "| 还 不错 ， 就是 有 就...  | ['还', ' ', '不错', '... | 29      |\n",
      "| 4156 135 4 6515615341...  | ['4156', ' ', '135', ... | 15      |\n",
      "| 无语 ， 买 到 三星 芯...  | ['无', '语', ' ', '，... | 12      |\n",
      "| ...                       | ...                      | ...     |\n",
      "+---------------------------+--------------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "#testset.apply(lambda ins: list(chain.from_iterable(get_tokenized(ins['raw_words']))), new_field_name='words', is_input=True)\n",
    "\n",
    "testset.apply(lambda ins: get_tokenized(ins['raw_words']), new_field_name='words', is_input=True)\n",
    "\n",
    "testset.apply(lambda ins: len(ins['words']) ,new_field_name='seq_len',is_input=True)\n",
    "print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "###\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "#将DataSet按照ratio的比例拆分，返回两个DataSet\n",
    "\n",
    "#ratio (float) -- 0<ratio<1, 返回的第一个DataSet拥有 (1-ratio) 这么多数据，第二个DataSet拥有`ratio`这么多数据\n",
    "\n",
    "train_data, dev_data = dataset.split(0.1, shuffle=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------+--------+----------------------+---------+\n",
      "| raw_words            | target | words                | seq_len |\n",
      "+----------------------+--------+----------------------+---------+\n",
      "| 屏幕 小 了 ， 苹...  | 1      | ['屏幕', ' ', '小... | 71      |\n",
      "| 第 n 次 在 京东 ...  | 0      | ['第', ' ', 'n',...  | 47      |\n",
      "| 说好 的 免息 ， ...  | 2      | ['说好', ' ', '的... | 35      |\n",
      "| 可惜 不是 台积 电... | 1      | ['可惜', ' ', '不... | 15      |\n",
      "| 包装 太 随便         | 2      | ['包装', ' ', '太... | 5       |\n",
      "| 店大欺客 ， 4888...  | 1      | ['店大欺客', ' '...  | 63      |\n",
      "| 正品 ， 京东 真是... | 0      | ['正品', ' ', '，... | 17      |\n",
      "| 机身 有 伤痕 ， ...  | 2      | ['机身', ' ', '有... | 21      |\n",
      "| 外形 很漂亮 ， 颜... | 0      | ['外形', ' ', '很... | 27      |\n",
      "| 手机 正在 用 。 ...  | 0      | ['手机', ' ', '正... | 31      |\n",
      "| 在 京东 买过 好几... | 0      | ['在', ' ', '京东... | 27      |\n",
      "| 第一次 网购 苹果...  | 0      | ['第一次', ' ', ...  | 114     |\n",
      "| 怀疑 是 被 翻新 ...  | 1      | ['怀疑', ' ', '是... | 93      |\n",
      "| 算 还 不错 啦 ，...  | 0      | ['算', ' ', '还'...  | 105     |\n",
      "| iphone 6s 收到 了... | 1      | ['iphone', ' ', ...  | 115     |\n",
      "| 买 完 就 降价 10...  | 0      | ['买', ' ', '完'...  | 9       |\n",
      "| 手机 是 原装 的 ...  | 2      | ['手机', ' ', '是... | 39      |\n",
      "| 支持 京东 ， 物流... | 0      | ['支持', ' ', '京... | 11      |\n",
      "| 苹果 手机 很 不错... | 2      | ['苹果', ' ', '手... | 17      |\n",
      "| 吐槽 语句 发不出...  | 2      | ['吐槽', ' ', '语... | 17      |\n",
      "| 还 可以 吧 ， 感...  | 2      | ['还', ' ', '可以... | 41      |\n",
      "| 为什么 用 着 用 ...  | 2      | ['为什么', ' ', ...  | 21      |\n",
      "| 用 了 两天 电池电... | 2      | ['用', ' ', '了'...  | 17      |\n",
      "| 4 月 15 号 才 到...  | 1      | ['4', ' ', '月',...  | 47      |\n",
      "| 耳机 不好 使 耳机... | 2      | ['耳机', ' ', '不... | 25      |\n",
      "| 速度 还 行 就是 ...  | 2      | ['速度', ' ', '还... | 15      |\n",
      "| 京东 派送 就是 赞... | 0      | ['京东', ' ', '派... | 79      |\n",
      "| 比较 了 几家 还是... | 2      | ['比较', ' ', '了... | 199     |\n",
      "| 有时候 会 莫名 的... | 2      | ['有时候', ' ', ...  | 15      |\n",
      "| 是 正品 ， 希望 ...  | 0      | ['是', ' ', '正品... | 25      |\n",
      "| 刚买 就 降价 了 ...  | 2      | ['刚买', ' ', '就... | 21      |\n",
      "| ...                  | ...    | ...                  | ...     |\n",
      "+----------------------+--------+----------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "print(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11267 1251 3130\n"
     ]
    }
   ],
   "source": [
    "print(len(train_data),len(dev_data),len(testset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vocabulary(['这个', ' ', '系统', '很', '好']...)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = Vocabulary(min_freq=2).from_dataset(dataset, field_name='words')\n",
    "\n",
    "vocab.index_dataset(train_data, dev_data, testset, field_name='words', new_field_name='words')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 931 out of 6744 words in the pre-training embedding.\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding,StackEmbedding\n",
    "\n",
    "fastnlp_embed = StaticEmbedding(vocab, model_dir_or_name='cn-char-fastnlp-100d',min_freq=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STSeqCls(\n",
      "  (enc): StarTransEnc(\n",
      "    (embedding): StaticEmbedding(\n",
      "      (dropout_layer): Dropout(p=0)\n",
      "      (embedding): Embedding(6744, 100, padding_idx=0)\n",
      "    )\n",
      "    (emb_fc): Linear(in_features=100, out_features=300, bias=True)\n",
      "    (encoder): StarTransformer(\n",
      "      (norm): ModuleList(\n",
      "        (0): LayerNorm(torch.Size([300]), eps=1e-06, elementwise_affine=True)\n",
      "        (1): LayerNorm(torch.Size([300]), eps=1e-06, elementwise_affine=True)\n",
      "        (2): LayerNorm(torch.Size([300]), eps=1e-06, elementwise_affine=True)\n",
      "        (3): LayerNorm(torch.Size([300]), eps=1e-06, elementwise_affine=True)\n",
      "      )\n",
      "      (emb_drop): Dropout(p=0.1)\n",
      "      (ring_att): ModuleList(\n",
      "        (0): _MSA1(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (1): _MSA1(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (2): _MSA1(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (3): _MSA1(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "      )\n",
      "      (star_att): ModuleList(\n",
      "        (0): _MSA2(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (1): _MSA2(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (2): _MSA2(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "        (3): _MSA2(\n",
      "          (WQ): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WK): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WV): Conv2d(300, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (WO): Conv2d(256, 300, kernel_size=(1, 1), stride=(1, 1))\n",
      "          (drop): Dropout(p=0.0)\n",
      "        )\n",
      "      )\n",
      "      (pos_emb): Embedding(512, 300)\n",
      "    )\n",
      "  )\n",
      "  (cls): _Cls(\n",
      "    (fc): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=600, bias=True)\n",
      "      (1): LeakyReLU(negative_slope=0.01)\n",
      "      (2): Dropout(p=0.1)\n",
      "      (3): Linear(in_features=600, out_features=3, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# # 不知道咋用\n",
    "# from fastNLP.models import ESIM\n",
    "\n",
    "# # 这个不照\n",
    "# model_scim=ESIM(fastnlp_embed,num_labels=2, dropout_rate=0.3, dropout_embed=0.1)\n",
    "\n",
    "# print(model_scim)\n",
    "\n",
    "\n",
    "from fastNLP.models.star_transformer import STSeqCls\n",
    "\n",
    "# 这个不照\n",
    "model_stsc=STSeqCls(fastnlp_embed,num_cls=3, hidden_size=300\n",
    "                    , num_layers=4, num_head=8\n",
    "                    , head_dim=32, max_len=512, cls_hidden_size=600, emb_dropout=0.1, dropout=0.1)\n",
    "\n",
    "# ESIM(fastnlp_embed,num_labels=2, dropout_rate=0.3, dropout_embed=0.1)\n",
    "\n",
    "print(model_stsc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNNText(\n",
      "  (embed): Embedding(\n",
      "    (embed): StaticEmbedding(\n",
      "      (dropout_layer): Dropout(p=0)\n",
      "      (embedding): Embedding(6744, 100, padding_idx=0)\n",
      "    )\n",
      "    (dropout): Dropout(p=0.0)\n",
      "  )\n",
      "  (conv_pool): ConvMaxpool(\n",
      "    (convs): ModuleList(\n",
      "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
      "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
      "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1)\n",
      "  (fc): Linear(in_features=120, out_features=3, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.models import CNNText\n",
    "\n",
    "model_CNN = CNNText(fastnlp_embed, num_classes=3,dropout=0.1)\n",
    "\n",
    "print(model_CNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input fields after batch(if batch size is 2):\n",
      "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 71]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "training epochs started 2020-07-02-18-19-28-564899\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15f6553a598346519194face5b515bd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=3530.0), HTML(value='')), layout=Layout(d…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric,BCELoss\n",
    "\n",
    "trainer_CNN = Trainer(model=model_CNN, train_data=train_data, dev_data=dev_data,loss=CrossEntropyLoss(), metrics=AccuracyMetric())\n",
    "\n",
    "trainer_CNN.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
